{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0187c2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import pyceres\n",
    "import pycolmap\n",
    "import numpy as np\n",
    "from hloc.utils import viz_3d\n",
    "from copy import deepcopy\n",
    "import cv2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22e87330",
   "metadata": {},
   "source": [
    "## Set up the toy example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecb0a1b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_rotation(max_=np.pi*2):\n",
    "    \"\"\"Sample a random rotation and return it as a pycolmap qvec.\n",
    "\n",
    "    The axis is the direction of a standard-normal 3-vector and the angle\n",
    "    is drawn uniformly from [0, max_) radians.\n",
    "    \"\"\"\n",
    "    aa = np.random.randn(3)\n",
    "    aa *= np.random.rand()*max_ / np.linalg.norm(aa)\n",
    "    R = cv2.Rodrigues(aa)[0]  # axis-angle vector -> 3x3 rotation matrix\n",
    "    qvec = pycolmap.rotmat_to_qvec(R)\n",
    "    return qvec\n",
    "\n",
    "def invert(q, t):\n",
    "    \"\"\"Invert the rigid transform (q, t): returns (q^-1, -R(q)^T t).\"\"\"\n",
    "    return (pycolmap.invert_qvec(q), -pycolmap.qvec_to_rotmat(q).T@t)\n",
    "\n",
    "def error(qt1, qt2):\n",
    "    \"\"\"Return the (rotation, translation) error between two (qvec, tvec) poses.\n",
    "\n",
    "    Both values are norms taken on pycolmap.relative_pose(*qt1, *qt2): the\n",
    "    norm of the axis-angle vector of its rotation (the rotation angle in\n",
    "    radians) and the norm of its translation.\n",
    "    \"\"\"\n",
    "    q, t = pycolmap.relative_pose(*qt1, *qt2)\n",
    "    return (np.linalg.norm(cv2.Rodrigues(pycolmap.qvec_to_rotmat(q))[0]), np.linalg.norm(t))\n",
    "\n",
    "# Seed the global NumPy RNG so that Restart & Run All reproduces the same toy problem.\n",
    "SEED = 0\n",
    "np.random.seed(SEED)\n",
    "\n",
    "num = 20\n",
    "# Ground truth: random poses and their inverses.\n",
    "qt_w_i = [(sample_rotation(), np.random.rand(3)*10) for _ in range(num)]\n",
    "qt_i_w = [invert(q, t) for q, t in qt_w_i]\n",
    "\n",
    "# Relative pose between each pose and the next one; the modulo wraps the last\n",
    "# pose back to the first so that the measurements form a closed loop.\n",
    "qt_i_j = [pycolmap.relative_pose(*qt_i_w[(i+1)%num], *qt_i_w[i]) for i in range(num)]\n",
    "\n",
    "# Initial estimates: ground truth perturbed by a random rotation of up to pi/5 rad\n",
    "# and standard-normal translation noise.\n",
    "qt_i_w_init = [(pycolmap.rotmat_to_qvec(pycolmap.qvec_to_rotmat(q)\n",
    "                                        @ pycolmap.qvec_to_rotmat(sample_rotation(np.pi/5))),\n",
    "                t+np.random.randn(3)) for q, t in qt_i_w]\n",
    "# The first pose is left unperturbed.\n",
    "qt_i_w_init[0] = qt_i_w[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1d71bcf",
   "metadata": {},
   "source": [
    "## PGO with relative poses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81913517",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work on a copy so that qt_i_w_init stays available for the comparison below;\n",
    "# the arrays in qt_i_w_opt are the parameter blocks handed to the solver.\n",
    "qt_i_w_opt = deepcopy(qt_i_w_init)\n",
    "\n",
    "prob = pyceres.Problem()\n",
    "loss = pyceres.TrivialLoss()\n",
    "costs = []  # also reused by the residual consistency check in the next cell\n",
    "for i in range(num):\n",
    "    # One relative-pose cost per consecutive pair (i, i+1), wrapping around at the end.\n",
    "    # NOTE(review): np.eye(6) is presumably the 6x6 covariance/weight of the measurement - confirm.\n",
    "    cost = pyceres.factors.PoseGraphRelativeCost(*invert(*qt_i_j[i]), np.eye(6))\n",
    "    costs.append(cost)\n",
    "    prob.add_residual_block(cost, loss, [*qt_i_w_opt[i], *qt_i_w_opt[(i+1)%num]])\n",
    "    prob.set_manifold(qt_i_w_opt[i][0], pyceres.QuaternionManifold())\n",
    "# Hold the first pose fixed (it was initialised to its ground-truth value).\n",
    "prob.set_parameter_block_constant(qt_i_w_opt[0][0])\n",
    "prob.set_parameter_block_constant(qt_i_w_opt[0][1])\n",
    "\n",
    "options = pyceres.SolverOptions()\n",
    "options.linear_solver_type = pyceres.LinearSolverType.SPARSE_NORMAL_CHOLESKY\n",
    "options.minimizer_progress_to_stdout = False\n",
    "options.num_threads = -1\n",
    "summary = pyceres.SolverSummary()\n",
    "pyceres.solve(options, prob, summary)\n",
    "print(summary.BriefReport())\n",
    "\n",
    "# Mean (rotation, translation) error w.r.t. ground truth, before and after optimization.\n",
    "err_init = np.array([error(qt_i_w[i], qt_i_w_init[i]) for i in range(num)])\n",
    "err_opt = np.array([error(qt_i_w[i], qt_i_w_opt[i]) for i in range(num)])\n",
    "print('mean (rotation, translation) error, initial:  ', np.mean(err_init, 0))\n",
    "print('mean (rotation, translation) error, optimized:', np.mean(err_opt, 0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4509b114",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: evaluated at the initial estimates, the residual of every relative cost\n",
    "# must have rotation / translation norms equal to the errors computed with error().\n",
    "qt_i_j_init = [pycolmap.relative_pose(*qt_i_w_init[(i+1)%num], *qt_i_w_init[i]) for i in range(num)]\n",
    "for i in range(num):\n",
    "    measured, predicted = invert(*qt_i_j[i]), invert(*qt_i_j_init[i])\n",
    "    error_rel_init = error(measured, predicted)  # qt_j_init_j\n",
    "    pose_i, pose_j = qt_i_w_init[i], qt_i_w_init[(i+1)%num]\n",
    "    res = costs[i].evaluate(*pose_i, *pose_j)[0]\n",
    "    # First three residual entries are compared to the rotation error, the rest to the translation error.\n",
    "    rot_norm, trans_norm = np.linalg.norm(res[:3]), np.linalg.norm(res[3:])\n",
    "    assert np.allclose(error_rel_init, (rot_norm, trans_norm))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b89ed97",
   "metadata": {},
   "source": [
    "## PGO with absolute poses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4adde3a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restart from the perturbed initial estimates; the copy keeps qt_i_w_init intact\n",
    "# for the comparison below.\n",
    "qt_i_w_opt = deepcopy(qt_i_w_init)\n",
    "\n",
    "prob = pyceres.Problem()\n",
    "loss = pyceres.TrivialLoss()\n",
    "costs = []\n",
    "for i in range(num):\n",
    "    # One absolute cost per pose, built from its ground-truth value.\n",
    "    # NOTE(review): np.eye(6) is presumably the 6x6 covariance/weight of the prior - confirm.\n",
    "    cost = pyceres.factors.PoseGraphAbsoluteCost(*qt_i_w[i], np.eye(6))\n",
    "    costs.append(cost)\n",
    "    prob.add_residual_block(cost, loss, [*qt_i_w_opt[i]])\n",
    "    prob.set_manifold(qt_i_w_opt[i][0], pyceres.QuaternionManifold())\n",
    "\n",
    "options = pyceres.SolverOptions()\n",
    "options.linear_solver_type = pyceres.LinearSolverType.SPARSE_NORMAL_CHOLESKY\n",
    "options.minimizer_progress_to_stdout = False\n",
    "options.num_threads = -1\n",
    "summary = pyceres.SolverSummary()\n",
    "pyceres.solve(options, prob, summary)\n",
    "print(summary.BriefReport())\n",
    "\n",
    "# Mean (rotation, translation) error w.r.t. ground truth, before and after optimization.\n",
    "err_init = np.array([error(qt_i_w[i], qt_i_w_init[i]) for i in range(num)])\n",
    "err_opt = np.array([error(qt_i_w[i], qt_i_w_opt[i]) for i in range(num)])\n",
    "print('mean (rotation, translation) error, initial:  ', np.mean(err_init, 0))\n",
    "print('mean (rotation, translation) error, optimized:', np.mean(err_opt, 0))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
